
Conversation

@tmcgrath325
Owner

This adds a keyword argument, `objective`, to the overlap function. A way to pass objective functions through to the bounds calculations (e.g. gogma_align -> branchbound -> generic_overlap -> passed objective) has also been added.

Further, objective functions for multi-GMMs can be supplied in a dictionary similar to the format for interactions:

# Allow features to be pulled closer or pushed further away (whichever is more favorable) by 0.5
relaxed_overlap(distsq, s, w) = gaussian_overlap(max(0, distsq - sign(w) * 0.5), s, w)

interactions = Dict(
    (:positive, :negative) =>  1.0,
    (:positive, :positive) => -1.0,
    (:negative, :negative) => -1.0,
    (:steric, :steric) => -1.0,
)
objective = Dict(
    (:positive, :negative) =>  gaussian_overlap, # the default objective function
    (:positive, :positive) => gaussian_overlap,
    (:negative, :negative) => gaussian_overlap,
    (:steric, :steric) => relaxed_overlap,
)

res = gogma_align(mgmmx, mgmmy; interactions=interactions, maxsplits=5e3, nextblockfun=randomblock, objective=objective)

Collaborator

@timholy left a comment


This is quite exciting! I left a couple of specific comments below. Here let me go over the callsite-specialization issue that appears below and may be something to keep in mind more generally.

Let's compare two implementations of a simple "depot" function:

julia> function dotrig1(x, funcname)
           f = funcname == "sin" ? sin :
               funcname == "cos" ? cos :
               funcname == "tan" ? tan :
               funcname == "sec" ? sec : error("$funcname not supported")
           return f(x)
       end

julia> function dotrig2(x, funcname)
           funcname == "sin" && return sin(x)
           funcname == "cos" && return cos(x)
           funcname == "tan" && return tan(x)
           funcname == "sec" && return sec(x)
           error("$funcname not supported")
       end

Now compare:

using Cthulhu
@descend iswarn=true dotrig1(π/4, "sin")
@descend iswarn=true dotrig2(π/4, "sin")

You'll note that dotrig1 is inferred to return Any whereas dotrig2 is inferred to return Float64. What's happening here is callsite specialization: for dotrig1, there's a single place in the code that makes the f(x) call, and so the compiler has to allow that call to be generic: it inserts code that says, "OK, what is f? Can I find a compiled method for that f and that ::typeof(x)? If not, compile it; then run it." There's a lot of "state" that might affect the return type of f(x) and at a certain point Julia's type-inference just gives up. (It has to: type-inference is subject to the halting problem, so it has to have heuristics that self-terminate.)

In contrast, with dotrig2, there are 4 separate call sites for f(x), each corresponding to a different f. This allows the compiler to specialize each one of those sites differently, and since it knows the f with certainty it can fully specialize each one, precompile all the calls, and thus execute the function just by jumping to a specific known-in-advance memory address that gets hardwired into the compiled code. You can't get more efficient than that (well, with inlining you might, but the good news is that Julia will even do that if appropriate).
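If Cthulhu isn't handy, a lighter-weight way to see the same difference (a sketch, not from the thread) is `Base.return_types`:

```julia
# Assumes the dotrig1/dotrig2 definitions from above are in scope.
rt1 = Base.return_types(dotrig1, (Float64, String))
rt2 = Base.return_types(dotrig2, (Float64, String))
# Per the discussion above, dotrig2 should infer to Float64, while dotrig1
# may widen to Any once inference gives up on the generic f(x) call.
```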

A related issue: if you have a list of functions that you want to apply to a lot of different variables, then

for x in varlist
    for f in funclist
        mysum += f(x)
    end
end

will be really bad, because there isn't a single call that can be predicted in advance: each and every one of those O(m*n) calls has to be individually analyzed ("what's f?"). In contrast,

for f in funclist
    mysum += apply_to_list(f, xlist)
end

@noinline apply_to_list(f::F, xlist) where F = sum(f, xlist)

forces Julia to specialize apply_to_list for each different f: Julia may not be able to predict which f will come out of funclist, but once it figures that out it calls one of many different compiled versions of apply_to_list, each of which is specialized to a particular f and so is highly efficient for iterating over xlist and aggregating the output. In other words, this is O(m) rather than O(m*n) in its runtime-dispatch performance.

This is an example of the function barrier trick. The f::F ... where F would ordinarily do nothing, but Julia has special heuristics for function- and type-arguments that avoid specialization in some cases (the number of Real subtypes is presumably limited in practice, but the number of Function subtypes is effectively unbounded so to avoid creating an infinite amount of compiled code Julia just decides not to specialize some code). In this case you may want to disable those heuristics, and adding a type-parameter achieves that. The @noinline is probably unnecessary with modern Julia, but at one point in Julia's development it was important to prevent inlining from defeating your efforts. I'm old-school so I insert these as a precaution.
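Putting the pieces together, here is a minimal self-contained sketch of the function-barrier pattern (the names here are illustrative, not from the package):

```julia
# Function barrier: the ::F type parameter forces specialization of
# apply_to_list on each concrete function type.
@noinline apply_to_list(f::F, xlist) where F = sum(f, xlist)

function total(funclist, xlist)
    mysum = 0.0
    for f in funclist
        # One runtime dispatch per function, not per element of xlist;
        # the inner reduction over xlist is fully specialized.
        mysum += apply_to_list(f, xlist)
    end
    return mysum
end

total(Any[sin, cos, abs], rand(100))  # Any[] makes f non-inferrable at the call
```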

`σᵣ` and `σₜ` represent the sizes of the rotation and translation uncertainty regions.
The `objective` should be a function that takes the squared distance between the means of two `IsotropicGaussian`s, the sum of their variances, and the product of their amplitudes.
Collaborator


You should spell out the exact arguments. From this I infer that the argument order is objective(Δμ, σ²sum, ϕprod), but best to be explicit.

Also, does objective need to satisfy certain requirements? E.g., does it have to be monotonic? (Lennard-Jones comes to mind.) Do you need to specialize any methods for your objective function? E.g., estimate_lower_bound(::typeof(lennardjones), Δμ, σ²sum, ϕprod). If there is an API that the user-supplied function needs to satisfy, it should be spelled out.

Owner Author


You're right that Lennard-Jones wouldn't work here without some careful thought -- maybe handling the attractive and repulsive terms separately.

The assumption is that the objective needs to be monotonically decreasing with increasing Δμ², which is indeed important to state clearly.
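A cheap numerical sanity check of that assumption could look like the following (entirely hypothetical, not part of the package):

```julia
# Hypothetical helper: check that an objective is monotonically
# non-increasing in Δμ² for fixed s and w, by sampling on a grid.
function check_monotone_decreasing(obj, s, w; dmax = 25.0, n = 1001)
    vals = [obj(d, s, w) for d in range(0, dmax; length = n)]
    return all(vals[i+1] <= vals[i] for i in 1:n-1)
end
```

Note that for negative amplitudes w the overlap increases toward zero with distance, so a real check would likely compare magnitudes or branch on sign(w).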

Comment on lines +130 to +131
obj = !isdict ? objective : (haskey(objective, (key1,key2)) ? objective[(key1,key2)] : objective[(key2,key1)])
lb, ub = (lb, ub) .+ generic_bounds(mgmmx.gmms[key1], mgmmy.gmms[key2], R, T, σᵣ, σₜ, pσ, mpϕ[key1][key2]; objective = obj, kwargs...)
Collaborator


If profiling reveals this line to be a bottleneck, here's one way that might help a tiny bit:

Suggested change
obj = !isdict ? objective : (haskey(objective, (key1,key2)) ? objective[(key1,key2)] : objective[(key2,key1)])
lb, ub = (lb, ub) .+ generic_bounds(mgmmx.gmms[key1], mgmmy.gmms[key2], R, T, σᵣ, σₜ, pσ, mpϕ[key1][key2]; objective = obj, kwargs...)
if !isdict
    lb, ub = (lb, ub) .+ generic_bounds(mgmmx.gmms[key1], mgmmy.gmms[key2], R, T, σᵣ, σₜ, pσ, mpϕ[key1][key2]; objective, kwargs...) # this call *might* be inferrable
else
    obj = get(objective, (key1, key2), nothing)
    if obj === nothing
        obj = get(objective, (key2, key1), nothing)
    end
    if obj !== nothing
        lb, ub = (lb, ub) .+ generic_bounds(mgmmx.gmms[key1], mgmmy.gmms[key2], R, T, σᵣ, σₜ, pσ, mpϕ[key1][key2]; objective = obj, kwargs...) # this call is not
    end
end

The important part of this is the !isdict case, which gives Julia a chance to "pass down" knowledge of typeof(objective) from the input arguments to the call to generic_bounds which can be specialized for objective=objective. For the second one, if you know (or can compute) the return type of generic_bounds then you might want to add a type-annotation, e.g., generic_bounds(args...; kwargs...)::Tuple{T,T} so that lb, ub has known type even if Julia can't infer its way through the entire call.

The other part of the change is vastly less important, but exploits the fact that

haskey(dict, a) ? dict[a] : nothing

involves looking up the key a twice, whereas

get(dict, a, nothing)

only looks up the key a once. Note that while nothing is a conventional default, if nothing is in fact a legitimate user-supplied value in the dictionary, you can ensure there's no ambiguity about whether the key was present in the dictionary as follows:

struct NotFound end    # a private type for internal use only
const notfound = NotFound()

x = get(dict, key, notfound)
if x !== notfound
    ...

Then there's no way that notfound was retrieved from dict (or if it is, it's clearly the user's fault).

You probably have enough places in your code that might check both orders of the keys that it might be worth splitting the double-get block into a utility function.
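Such a utility (name and signature hypothetical) might look like:

```julia
struct NotFound end            # private sentinel, as above
const notfound = NotFound()

# Look up (a, b), falling back to (b, a); return `default` if neither
# key order is present in the dictionary.
function get_unordered(dict, a, b, default)
    x = get(dict, (a, b), notfound)
    x !== notfound && return x
    return get(dict, (b, a), default)
end
```

At the callsite above this would become something like `obj = get_unordered(objective, key1, key2, gaussian_overlap)`.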

randtform = AffineMap(RotationVec(0.1*rand(3)...), SVector{3}(0.1*rand(3)...))

# allowing some fuzziness in the distance
relaxed_overlap(distsq, s, w) = gaussian_overlap(max(0, distsq - sign(w) * 0.5), s, w)
Collaborator


👍

@tmcgrath325
Owner Author

Thanks for the feedback! I was worried about runtime dispatch issues, and the function barrier trick seems really helpful here.

@timholy
Collaborator

timholy commented Jun 6, 2024

Another option that might be worth taking seriously would be stashing the interaction overlap functions in

struct InteractionOverlap{FS,FH,FI,FP}
    steric::FS
    hydrophobic::FH
    ionic::FI
    polar::FP
end

I'm not sure that's a good idea, but it certainly would make everything inferrable.
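For illustration, constructing such a struct bakes each function's concrete type into the container's type, so every field access dispatches statically (the overlap formulas below are placeholders, not the package's actual definitions):

```julia
struct InteractionOverlap{FS,FH,FI,FP}
    steric::FS
    hydrophobic::FH
    ionic::FI
    polar::FP
end

# Placeholder overlap functions, just for the type demonstration.
gauss(distsq, s, w) = w * exp(-distsq / (2s))
relaxed(distsq, s, w) = gauss(max(0, distsq - sign(w) * 0.5), s, w)

ov = InteractionOverlap(relaxed, gauss, gauss, gauss)
# typeof(ov) records all four function types, so a call like
# ov.steric(d, s, w) is fully inferrable.
```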

But profiling is the real decider, here. "Strategic non-inferrability" can be a good thing and may not hurt your performance while also letting you simplify your code. Mostly it's about knowing what bottlenecks you have and what tricks you have at your disposal for fixing them. It's usually not worth fixing inference problems unless they are affecting performance.
